import tensorflow as tf

"""
反卷积的原理
反卷积可以理解为卷积操作的逆操作，但是不要认为它能够复原出卷积操作的输入值
它仅仅是将卷积变换的过程中的步骤反向变换一次而已，通过将卷积核转置(并不是矩阵的转置)
与卷积后得结果再做一遍卷积，所以它还叫转置卷积。虽然其不能还原出原来卷积的样子，
但还是有类似的效果，只是将小部分缺失信息最大化的恢复
应用
一般可以用来信道均衡、图像恢复、语音识别、地震学、无损探伤等未知输入估计和过程辨识方面的问题
更多就是充当可视化的作用
反卷积的操作步骤：
1）首先将卷积核反转(并不是转置，而是上下左右方向进行递序操作)
2）再将卷积结果作为输入，做补0的扩充操作，即往每一个元素后面补0，这一步是根据步长来的，
   对每一个元素沿着步长的方向补(步长-1)个0，
3）在扩充后的输入基础上再对整体补0。以原始输入的shape作为输出，按照卷积padding规则，
  计算padding的补0位置及个数(统一按照padding='SAME', 步长为1)，得到的补0位置要上下和左右各自颠倒一下
4）将补0后得卷积结果作为真正的输入，反转后的卷积核为filter，进行步长为1的卷积操作

TensorFlow中的操作函数：
def conv2d_transpose(
    value,                # 代表通过卷积操作之后的Tensor，一般用 NHWC 类型
    filter,               # 卷积核
    output_shape,         # 输出的Tensor形状也是个4维Tensor，value参数的原数据形状
    strides,              # 步长
    padding="SAME",       # 代表原数据生成value时使用的补0方式，是用来检查输入形状和输出形状是否合规的
    data_format="NHWC",   #  神经网络中在图像处理方面常用的类型，N--个数，H--高，W--宽，C--通道数
    name=None)

    在TensorFlow源码中，反卷积操作其实是使用 gen_nn_ops.conv2d_backprop_input 函数最终实现
    即 TensorFlow 中利用了卷积操作在反向传播的处理函数中做反卷积操作，即卷及操作的反向传播就是反卷积操作
"""

"""
实例，并比较卷积与反卷积中padding在SAME和VALID下的变化
"""

# 模拟数据
img = tf.Variable(tf.constant(1., shape=[1, 4, 4, 1]))
filter_ = tf.Variable(tf.constant([1., 0, -1, -2], shape=[2, 2, 1, 1]))

con_s = tf.nn.conv2d(img, filter_, strides=[1, 2, 2, 1], padding='SAME')
con_v = tf.nn.conv2d(img, filter_, strides=[1, 2, 2, 1], padding='VALID')

print(con_s.shape)
print(con_v.shape)

# 反卷积
con_ts = tf.nn.conv2d_transpose(con_v, filter_, output_shape=[1, 4, 4, 1], strides=[1, 2, 2, 1], padding='SAME')
con_tv = tf.nn.conv2d_transpose(con_v, filter_, output_shape=[1, 4, 4, 1], strides=[1, 2, 2, 1], padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print('con_s:\n', sess.run([con_s, filter_]))
    print('con_v:\n', sess.run([con_v]))
    print('con_ts:\n', sess.run([con_ts]))
    print('con_tv:\n', sess.run([con_tv]))

# con_s:
#  [
#   array([[
#         [[-2.],[-2.]],
#         [[-2.],[-2.]]
#      ]], dtype=float32),
#   array([
#          [
#           [[ 1.]],[[ 0.]]
#          ],
#          [
#           [[-1.]],[[-2.]]
#          ]
#         ], dtype=float32)
#   ]
# con_v:
#  [array([[
#           [[-2.],[-2.]],
#           [[-2.],[-2.]]
#         ]], dtype=float32)]
# con_ts:
#  [array([[
#           [[-2.],[ 0.],[-2.],[ 0.]],
#           [[ 2.],[ 4.],[ 2.],[ 4.]],
#           [[-2.],[ 0.],[-2.],[ 0.]],
#           [[ 2.],[ 4.],[ 2.],[ 4.]]
#         ]], dtype=float32)]
# con_tv:
#  [array([[
#           [[-2.],[ 0.],[-2.],[ 0.]],
#           [[ 2.],[ 4.],[ 2.],[ 4.]],
#           [[-2.],[ 0.],[-2.],[ 0.]],
#           [[ 2.],[ 4.],[ 2.],[ 4.]]
#         ]], dtype=float32)]

